{
 "cells": [
  {
   "cell_type": "raw",
   "id": "df7d42b9-58a6-434c-a2d7-0b61142f6d3e",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_position: 7\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2195672-0cab-4967-ba8a-c6544635547d",
   "metadata": {},
   "source": [
    "# Deal with High Cardinality Categoricals\n",
    "\n",
    "You may want to do query analysis to create a filter on a categorical column. One difficulty is that the filter usually requires the EXACT categorical value, so the LLM must generate that value exactly. With only a few valid values, this is relatively easy to achieve with prompting. With a large number of valid values it becomes harder: the values may not fit in the LLM's context window, or, if they do, there may be too many for the LLM to attend to properly.\n",
    "\n",
    "In this notebook we take a look at how to approach this."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4079b57-4369-49c9-b2ad-c809b5408d7e",
   "metadata": {},
   "source": [
    "## Setup\n",
    "#### Install dependencies\n",
    "\n",
    "```{=mdx}\n",
    "import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n",
    "import Npm2Yarn from \"@theme/Npm2Yarn\";\n",
    "\n",
    "<IntegrationInstallTooltip></IntegrationInstallTooltip>\n",
    "\n",
    "<Npm2Yarn>\n",
    "  @langchain/core @langchain/community @langchain/openai zod chromadb @faker-js/faker\n",
    "</Npm2Yarn>\n",
    "```\n",
    "\n",
    "#### Set environment variables\n",
    "\n",
    "We'll use OpenAI in this example:\n",
    "\n",
    "```\n",
    "OPENAI_API_KEY=your-api-key\n",
    "\n",
    "# Optional, use LangSmith for best-in-class observability\n",
    "LANGSMITH_API_KEY=your-api-key\n",
    "LANGCHAIN_TRACING_V2=true\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8d47f4b",
   "metadata": {},
   "source": [
    "#### Set up data\n",
    "\n",
    "We will generate 10,000 fake names to act as the high-cardinality categorical values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e5ba65c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import { faker } from \"@faker-js/faker\";\n",
    "\n",
    "const names = Array.from({ length: 10000 }, () => faker.person.fullName());"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41133694",
   "metadata": {},
   "source": [
    "Let's look at some of the names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "c901ea97",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m\"Dale Kessler\"\u001b[39m"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "names[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b0d42ae2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[32m\"Mrs. Chelsea Bayer MD\"\u001b[39m"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "names[567]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1725883d",
   "metadata": {},
   "source": [
    "## Query Analysis\n",
    "\n",
    "We can now set up a baseline query analyzer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "6c9485ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import { z } from \"zod\";\n",
    "\n",
    "const searchSchema = z.object({\n",
    "    query: z.string(),\n",
    "    author: z.string(),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "aebd704a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import { ChatPromptTemplate } from \"@langchain/core/prompts\";\n",
    "import { RunnablePassthrough, RunnableSequence } from \"@langchain/core/runnables\";\n",
    "import { ChatOpenAI } from \"@langchain/openai\";\n",
    "\n",
    "const system = `Generate a relevant search query for a library system`;\n",
    "const prompt = ChatPromptTemplate.fromMessages(\n",
    "    [\n",
    "      [\"system\", system],\n",
    "      [\"human\", \"{question}\"],\n",
    "    ]\n",
    ")\n",
    "const llm = new ChatOpenAI({\n",
    "  modelName: \"gpt-3.5-turbo-0125\",\n",
    "  temperature: 0\n",
    "});\n",
    "const llmWithTools = llm.withStructuredOutput(searchSchema, {\n",
    "  name: \"Search\"\n",
    "})\n",
    "const queryAnalyzer = RunnableSequence.from([\n",
    "  {\n",
    "    question: new RunnablePassthrough(),\n",
    "  },\n",
    "  prompt,\n",
    "  llmWithTools\n",
    "]);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41709a2e",
   "metadata": {},
   "source": [
    "If we spell the name exactly right, the model handles it correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "cc0d344b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ query: \u001b[32m\"books about aliens\"\u001b[39m, author: \u001b[32m\"Jesse Knight\"\u001b[39m }"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await queryAnalyzer.invoke(\"what are books about aliens by Jesse Knight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b57eab",
   "metadata": {},
   "source": [
    "The issue is that the values you want to filter on may NOT be spelled exactly right in the question. In that case the model simply passes the misspelling through:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "82b6b2ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ query: \u001b[32m\"books about aliens\"\u001b[39m, author: \u001b[32m\"Jess Knight\"\u001b[39m }"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await queryAnalyzer.invoke(\"what are books about aliens by jess knight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b60b7c2",
   "metadata": {},
   "source": [
    "### Add in all values\n",
    "\n",
    "One way around this is to add ALL possible values to the prompt. That will generally guide the model toward a valid value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "98788a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "const system = `Generate a relevant search query for a library system using the 'search' tool.\n",
    "\n",
    "The 'author' you return to the user MUST be one of the following authors:\n",
    "\n",
    "{authors}\n",
    "\n",
    "Do NOT hallucinate author name!`\n",
    "const basePrompt = ChatPromptTemplate.fromMessages(\n",
    "    [\n",
    "      [\"system\", system],\n",
    "      [\"human\", \"{question}\"],\n",
    "    ]\n",
    ")\n",
    "const prompt = await basePrompt.partial({ authors: names.join(\", \") })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "e65412f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "const queryAnalyzerAll = RunnableSequence.from([\n",
    "  {\n",
    "    question: new RunnablePassthrough(),\n",
    "  },\n",
    "  prompt,\n",
    "  llmWithTools\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e639285a",
   "metadata": {},
   "source": [
    "However... if the list of categoricals is long enough, it may error!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "696b000f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error: 400 This model's maximum context length is 16385 tokens. However, your messages resulted in 49822 tokens (49792 in the messages, 30 in the functions). Please reduce the length of the messages or functions.\n",
      "    at Function.generate (file:///Users/bracesproul/Library/Caches/deno/npm/registry.npmjs.org/openai/4.28.4/error.mjs:40:20)\n",
      "    at OpenAI.makeStatusError (file:///Users/bracesproul/Library/Caches/deno/npm/registry.npmjs.org/openai/4.28.4/core.mjs:256:25)\n",
      "    at OpenAI.makeRequest (file:///Users/bracesproul/Library/Caches/deno/npm/registry.npmjs.org/openai/4.28.4/core.mjs:299:30)\n",
      "    at eventLoopTick (ext:core/01_core.js:63:7)\n",
      "    at async file:///Users/bracesproul/Library/Caches/deno/npm/registry.npmjs.org/@langchain/openai/0.0.15/dist/chat_models.js:650:29\n",
      "    at async RetryOperation._fn (file:///Users/bracesproul/Library/Caches/deno/npm/registry.npmjs.org/p-retry/4.6.2/index.js:50:12) {\n",
      "  status: 400,\n",
      "  headers: {\n",
      "    \"access-control-allow-origin\": \"*\",\n",
      "    \"alt-svc\": 'h3=\":443\"; ma=86400',\n",
      "    \"cf-cache-status\": \"DYNAMIC\",\n",
      "    \"cf-ray\": \"85f6e713581815d0-SJC\",\n",
      "    \"content-length\": \"341\",\n",
      "    \"content-type\": \"application/json\",\n",
      "    date: \"Tue, 05 Mar 2024 03:08:39 GMT\",\n",
      "    \"openai-organization\": \"langchain\",\n",
      "    \"openai-processing-ms\": \"349\",\n",
      "    \"openai-version\": \"2020-10-01\",\n",
      "    server: \"cloudflare\",\n",
      "    \"set-cookie\": \"_cfuvid=NXe7nstRj6UNdFs5F8k49JZF6Tz7EE8dfKwYRpV3AWI-1709608119946-0.0.1.1-604800000; path=/; domain=\"... 48 more characters,\n",
      "    \"strict-transport-security\": \"max-age=15724800; includeSubDomains\",\n",
      "    \"x-ratelimit-limit-requests\": \"10000\",\n",
      "    \"x-ratelimit-limit-tokens\": \"2000000\",\n",
      "    \"x-ratelimit-remaining-requests\": \"9999\",\n",
      "    \"x-ratelimit-remaining-tokens\": \"1958537\",\n",
      "    \"x-ratelimit-reset-requests\": \"6ms\",\n",
      "    \"x-ratelimit-reset-tokens\": \"1.243s\",\n",
      "    \"x-request-id\": \"req_99890749d442033c6145f9a8f1324aea\"\n",
      "  },\n",
      "  error: {\n",
      "    message: \"This model's maximum context length is 16385 tokens. However, your messages resulted in 49822 tokens\"... 101 more characters,\n",
      "    type: \"invalid_request_error\",\n",
      "    param: \"messages\",\n",
      "    code: \"context_length_exceeded\"\n",
      "  },\n",
      "  code: \"context_length_exceeded\",\n",
      "  param: \"messages\",\n",
      "  type: \"invalid_request_error\",\n",
      "  attemptNumber: 1,\n",
      "  retriesLeft: 6\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "try {\n",
    "    const res = await queryAnalyzerAll.invoke(\"what are books about aliens by jess knight\")\n",
    "} catch (e) {\n",
    "    console.error(e)\n",
    "}"
   ]
  },
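  {
   "cell_type": "markdown",
   "id": "b3f1a6c2",
   "metadata": {},
   "source": [
    "Before sending a prompt like this, it can help to estimate whether the joined values will fit at all. The following is a rough sketch rather than part of the original example: `estimateTokens` and `fitsInContext` are hypothetical helpers, the sample names are stand-in data, and the ratio of about 4 characters per token is only a common rule of thumb for English text, not a real tokenizer count. The 16385-token limit is the one reported in the error above.\n",
    "\n",
    "```typescript\n",
    "// Rough heuristic (an assumption): about 4 characters per token for English text.\n",
    "const CHARS_PER_TOKEN = 4;\n",
    "\n",
    "// Hypothetical helper: estimate the token count of a string.\n",
    "function estimateTokens(text: string): number {\n",
    "  return Math.ceil(text.length / CHARS_PER_TOKEN);\n",
    "}\n",
    "\n",
    "// Hypothetical helper: check whether the joined values fit in a token budget.\n",
    "function fitsInContext(values: string[], maxTokens: number): boolean {\n",
    "  return estimateTokens(values.join(\", \")) <= maxTokens;\n",
    "}\n",
    "\n",
    "// Stand-in data for illustration; the notebook uses 10,000 faker-generated names.\n",
    "const sampleNames = Array.from({ length: 10000 }, (_, i) => `Author Name ${i}`);\n",
    "console.log(fitsInContext(sampleNames, 16385));\n",
    "```\n",
    "\n",
    "For an exact count you would need the model's actual tokenizer. The heuristic is only meant to flag prompts that are far over the limit."
   ]
  },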
  {
   "cell_type": "markdown",
   "id": "1d5d7891",
   "metadata": {},
   "source": [
    "We can try a model with a longer context window, but with so much information in the prompt, it is not guaranteed to pick out the right value reliably:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "0f0d0757",
   "metadata": {},
   "outputs": [],
   "source": [
    "const llmLong = new ChatOpenAI({\n",
    "  modelName: \"gpt-4-turbo-preview\",\n",
    "  temperature: 0,\n",
    "})\n",
    "const structuredLlmLong = llmLong.withStructuredOutput(searchSchema, {\n",
    "  name: \"Search\"\n",
    "});\n",
    "const queryAnalyzerAll = RunnableSequence.from([\n",
    "  {\n",
    "    question: new RunnablePassthrough(),\n",
    "  },\n",
    "  prompt,\n",
    "  structuredLlmLong\n",
    "]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "03e5b7b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ query: \u001b[32m\"aliens\"\u001b[39m, author: \u001b[32m\"Jess Knight\"\u001b[39m }"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await queryAnalyzerAll.invoke(\"what are books about aliens by jess knight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73ecf52b",
   "metadata": {},
   "source": [
    "### Find all relevant values\n",
    "\n",
    "Instead, we can create an index over the valid values and query it for the N most relevant ones, so that only those are included in the prompt."
   ]
  },
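  {
   "cell_type": "markdown",
   "id": "c7e9d4a1",
   "metadata": {},
   "source": [
    "To make the idea concrete without a vector store, here is a minimal sketch of selecting the N most relevant values. It is an illustration only: `trigrams`, `similarity`, and `topN` are hypothetical helpers, the candidate list is stand-in data, and character-trigram overlap stands in for the embedding similarity that the rest of this notebook uses.\n",
    "\n",
    "```typescript\n",
    "// Hypothetical helper: the set of lowercase character trigrams in a string.\n",
    "function trigrams(text: string): Set<string> {\n",
    "  const s = text.toLowerCase();\n",
    "  const grams = new Set<string>();\n",
    "  for (let i = 0; i + 3 <= s.length; i++) {\n",
    "    grams.add(s.slice(i, i + 3));\n",
    "  }\n",
    "  return grams;\n",
    "}\n",
    "\n",
    "// Jaccard similarity between the trigram sets of two strings (0 to 1).\n",
    "function similarity(a: string, b: string): number {\n",
    "  const ga = trigrams(a);\n",
    "  const gb = trigrams(b);\n",
    "  let shared = 0;\n",
    "  ga.forEach((g) => {\n",
    "    if (gb.has(g)) shared++;\n",
    "  });\n",
    "  const union = ga.size + gb.size - shared;\n",
    "  return union === 0 ? 0 : shared / union;\n",
    "}\n",
    "\n",
    "// Return the n candidates most similar to the query, best first.\n",
    "function topN(query: string, candidates: string[], n: number): string[] {\n",
    "  return candidates\n",
    "    .map((value) => ({ value, score: similarity(query, value) }))\n",
    "    .sort((x, y) => y.score - x.score)\n",
    "    .slice(0, n)\n",
    "    .map((entry) => entry.value);\n",
    "}\n",
    "\n",
    "// Stand-in candidate values for illustration.\n",
    "const candidates = [\"Jesse Knight\", \"Dale Kessler\", \"Mrs. Chelsea Bayer MD\"];\n",
    "console.log(topN(\"jess knight\", candidates, 2));\n",
    "```\n",
    "\n",
    "The cells that follow take the same approach, but rank candidates with OpenAI embeddings stored in a Chroma vector store instead of trigram overlap."
   ]
  },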
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "32b19e07",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Module: null prototype] {\n",
       "  AdminClient: \u001b[36m[class AdminClient]\u001b[39m,\n",
       "  ChromaClient: \u001b[36m[class ChromaClient]\u001b[39m,\n",
       "  CloudClient: \u001b[36m[class CloudClient extends ChromaClient]\u001b[39m,\n",
       "  CohereEmbeddingFunction: \u001b[36m[class CohereEmbeddingFunction]\u001b[39m,\n",
       "  Collection: \u001b[36m[class Collection]\u001b[39m,\n",
       "  DefaultEmbeddingFunction: \u001b[36m[class _DefaultEmbeddingFunction]\u001b[39m,\n",
       "  GoogleGenerativeAiEmbeddingFunction: \u001b[36m[class _GoogleGenerativeAiEmbeddingFunction]\u001b[39m,\n",
       "  HuggingFaceEmbeddingServerFunction: \u001b[36m[class HuggingFaceEmbeddingServerFunction]\u001b[39m,\n",
       "  IncludeEnum: {\n",
       "    Documents: \u001b[32m\"documents\"\u001b[39m,\n",
       "    Embeddings: \u001b[32m\"embeddings\"\u001b[39m,\n",
       "    Metadatas: \u001b[32m\"metadatas\"\u001b[39m,\n",
       "    Distances: \u001b[32m\"distances\"\u001b[39m\n",
       "  },\n",
       "  JinaEmbeddingFunction: \u001b[36m[class JinaEmbeddingFunction]\u001b[39m,\n",
       "  OpenAIEmbeddingFunction: \u001b[36m[class _OpenAIEmbeddingFunction]\u001b[39m,\n",
       "  TransformersEmbeddingFunction: \u001b[36m[class _TransformersEmbeddingFunction]\u001b[39m\n",
       "}"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import { Chroma } from \"@langchain/community/vectorstores/chroma\";\n",
    "import { OpenAIEmbeddings } from \"@langchain/openai\";\n",
    "import \"chromadb\";\n",
    "\n",
    "const embeddings = new OpenAIEmbeddings({\n",
    "  modelName: \"text-embedding-3-small\",\n",
    "})\n",
    "const vectorstore = await Chroma.fromTexts(names, {}, embeddings, {\n",
    "  collectionName: \"author_names\"\n",
    "})"
   ]
  },
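  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "774cb7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "// Retrieve the 10 indexed names most similar to the question\n",
    "// and join them into a single string for the prompt.\n",
    "const selectNames = async (question: string) => {\n",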
    "    const _docs = await vectorstore.similaritySearch(question, 10);\n",
    "    const _names = _docs.map(d => d.pageContent);\n",
    "    return _names.join(\", \");\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "1173159c",
   "metadata": {},
   "outputs": [],
   "source": [
    "const createPrompt = RunnableSequence.from([\n",
    "    {\n",
    "        question: new RunnablePassthrough(),\n",
    "        authors: selectNames,\n",
    "    },\n",
    "    basePrompt\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "0a892607",
   "metadata": {},
   "outputs": [],
   "source": [
    "const queryAnalyzerSelect = createPrompt.pipe(llmWithTools);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "8195d7cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ChatPromptValue {\n",
       "  lc_serializable: \u001b[33mtrue\u001b[39m,\n",
       "  lc_kwargs: {\n",
       "    messages: [\n",
       "      SystemMessage {\n",
       "        lc_serializable: \u001b[33mtrue\u001b[39m,\n",
       "        lc_kwargs: {\n",
       "          content: \u001b[32m\"Generate a relevant search query for a library system using the 'search' tool.\\n\"\u001b[39m +\n",
       "            \u001b[32m\"\\n\"\u001b[39m +\n",
       "            \u001b[32m\"The 'author' you ret\"\u001b[39m... 259 more characters,\n",
       "          additional_kwargs: {}\n",
       "        },\n",
       "        lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"messages\"\u001b[39m ],\n",
       "        content: \u001b[32m\"Generate a relevant search query for a library system using the 'search' tool.\\n\"\u001b[39m +\n",
       "          \u001b[32m\"\\n\"\u001b[39m +\n",
       "          \u001b[32m\"The 'author' you ret\"\u001b[39m... 259 more characters,\n",
       "        name: \u001b[90mundefined\u001b[39m,\n",
       "        additional_kwargs: {}\n",
       "      },\n",
       "      HumanMessage {\n",
       "        lc_serializable: \u001b[33mtrue\u001b[39m,\n",
       "        lc_kwargs: {\n",
       "          content: \u001b[32m\"what are books by jess knight\"\u001b[39m,\n",
       "          additional_kwargs: {}\n",
       "        },\n",
       "        lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"messages\"\u001b[39m ],\n",
       "        content: \u001b[32m\"what are books by jess knight\"\u001b[39m,\n",
       "        name: \u001b[90mundefined\u001b[39m,\n",
       "        additional_kwargs: {}\n",
       "      }\n",
       "    ]\n",
       "  },\n",
       "  lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"prompt_values\"\u001b[39m ],\n",
       "  messages: [\n",
       "    SystemMessage {\n",
       "      lc_serializable: \u001b[33mtrue\u001b[39m,\n",
       "      lc_kwargs: {\n",
       "        content: \u001b[32m\"Generate a relevant search query for a library system using the 'search' tool.\\n\"\u001b[39m +\n",
       "          \u001b[32m\"\\n\"\u001b[39m +\n",
       "          \u001b[32m\"The 'author' you ret\"\u001b[39m... 259 more characters,\n",
       "        additional_kwargs: {}\n",
       "      },\n",
       "      lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"messages\"\u001b[39m ],\n",
       "      content: \u001b[32m\"Generate a relevant search query for a library system using the 'search' tool.\\n\"\u001b[39m +\n",
       "        \u001b[32m\"\\n\"\u001b[39m +\n",
       "        \u001b[32m\"The 'author' you ret\"\u001b[39m... 259 more characters,\n",
       "      name: \u001b[90mundefined\u001b[39m,\n",
       "      additional_kwargs: {}\n",
       "    },\n",
       "    HumanMessage {\n",
       "      lc_serializable: \u001b[33mtrue\u001b[39m,\n",
       "      lc_kwargs: {\n",
       "        content: \u001b[32m\"what are books by jess knight\"\u001b[39m,\n",
       "        additional_kwargs: {}\n",
       "      },\n",
       "      lc_namespace: [ \u001b[32m\"langchain_core\"\u001b[39m, \u001b[32m\"messages\"\u001b[39m ],\n",
       "      content: \u001b[32m\"what are books by jess knight\"\u001b[39m,\n",
       "      name: \u001b[90mundefined\u001b[39m,\n",
       "      additional_kwargs: {}\n",
       "    }\n",
       "  ]\n",
       "}"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await createPrompt.invoke(\"what are books by jess knight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "d3228b4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ query: \u001b[32m\"books about aliens\"\u001b[39m, author: \u001b[32m\"Jessica Kerluke\"\u001b[39m }"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await queryAnalyzerSelect.invoke(\"what are books about aliens by jess knight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Deno",
   "language": "typescript",
   "name": "deno"
  },
  "language_info": {
   "file_extension": ".ts",
   "mimetype": "text/x.typescript",
   "name": "typescript",
   "nb_converter": "script",
   "pygments_lexer": "typescript",
   "version": "5.3.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
